(fix) Make bias statistics complete for all elements #4496
base: devel
Conversation
[pre-commit.ci] auto fixes from pre-commit hooks — for more information, see https://pre-commit.ci
📝 Walkthrough

The pull request introduces two significant modifications in DeepMD-kit's PyTorch utility modules: `deepmd/pt/utils/dataset.py` gains a mapping from element types to the frames that contain them, and `deepmd/pt/utils/stat.py` extends statistics collection so that element types missing from the sampled batches are backfilled from those frames.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Dataset
    participant StatCollector
    Dataset->>Dataset: Create element_to_frames mapping
    StatCollector->>Dataset: Request missing element types
    Dataset-->>StatCollector: Provide frame data for missing elements
    StatCollector->>StatCollector: Update statistics with new data
```
The sequence diagram illustrates the new workflow: the dataset provides a mapping of element types to frames, and the statistics collector can request and integrate data for any missing atomic types, ensuring comprehensive statistical representation.
Actionable comments posted: 1
🧹 Nitpick comments (1)
deepmd/pt/utils/stat.py (1)
`110-119`: Avoid potential index-out-of-range errors when concatenating Tensors.

The code checks `isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor)` without verifying that the list is non-empty, which can raise an `IndexError` for an empty list. Consider adding a length check before accessing `sys_stat[key][0]`. Possible safe check:

```diff
-if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
+if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
+        and isinstance(sys_stat[key][0], torch.Tensor):
```
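For illustration, a standalone sketch of the guarded concatenation (hypothetical `sys_stat` contents, not the file's actual surrounding code):

```python
import torch

# Hypothetical stat dict: one populated tensor list, one empty list.
sys_stat = {"atype": [torch.tensor([0, 1])], "energy": []}

for key in list(sys_stat):
    value = sys_stat[key]
    # Length check first, so an empty list never reaches value[0].
    if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
        sys_stat[key] = torch.cat(value)  # concatenate only non-empty tensor lists

print(sys_stat["atype"])  # tensor([0, 1]); "energy" stays an empty list
```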
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `deepmd/pt/utils/dataset.py` (1 hunks)
- `deepmd/pt/utils/stat.py` (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py
44-44: SyntaxError: Expected an indented block after function definition
🔇 Additional comments (3)
deepmd/pt/utils/stat.py (3)
`91-96`: Validate `atype` presence and handle empty arrays gracefully.

While updating the collected elements, consider explicitly checking that each `atype` array is non-empty. If a dataset yields an empty array, it might bypass your missing-elements logic.
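A hedged sketch of such an explicit guard (sample data; a `numel()` check is one way to skip empty tensors):

```python
import numpy as np
import torch

# Sample per-system stats: one populated atype tensor, one empty one.
lst = [{"atype": torch.tensor([0, 0, 1])}, {"atype": torch.tensor([], dtype=torch.long)}]

collected_elements: set[int] = set()
for sys_stat in lst:
    atype = sys_stat.get("atype")
    if atype is None or atype.numel() == 0:  # explicitly skip missing or empty arrays
        continue
    collected_elements.update(np.unique(atype.cpu().numpy()).tolist())

print(collected_elements)  # {0, 1}
```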
✅ Verification successful
Let me gather more information about how `atype` is handled in the codebase.
Let me try again with the correct file type flag for ripgrep.
Based on the code analysis, I can now generate the final response:
Empty arrays are handled safely by `numpy.unique` and `set.update`.

The code uses `numpy.unique()`, which safely handles empty arrays by returning an empty array, and `set.update()`, which safely handles empty iterables. The missing-elements logic therefore works correctly even with empty arrays, since an empty array contributes no unique values to the set.
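That claim is easy to verify in isolation (NumPy only):

```python
import numpy as np

collected = set()
empty = np.array([], dtype=np.int64)
collected.update(np.unique(empty))  # np.unique of an empty array is an empty array
assert collected == set()           # nothing added, so missing-element logic is unaffected
```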
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:
```shell
#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" --type python

# Also look for the specific class/function containing this code
ast-grep --pattern $'def $_(self, $_):
  $$$
  collected_elements = set()
  for sys_stat in lst:
    if "atype" in sys_stat:
      collected_elements.update($$$)
  $$$'
```
Length of output: 248
Script:
```shell
#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" -t py

# Look for similar array handling patterns in the codebase
rg -t py "numpy\.unique|\.update.*numpy" -A 2

# Find the specific file we're analyzing
fd "stat.py" --type f
```
Length of output: 89317
`85-90`: Consider using logging and verifying that all datasets are included.

You are only collecting the set of elements from the first dataset and simply printing them. If your intent is to gather possible elements across all datasets, you may want to iterate over all datasets. Also, using `log.info` would be more consistent than a raw `print` in a library setting.
Possible fix for the print statement:

```diff
- print('we want', all_elements)
+ log.info(f"We want elements: {all_elements}")
```
`97-109`: Ensure partial coverage logic for missing elements is intended.

Here, frames are appended until the number of records reaches `nbatches`. Verify that the short-circuiting with `break` actually meets your coverage needs for missing elements. If your goal is to gather as many frames as possible across all missing elements, you might need a more explicit limit or a different approach.
✅ Verification successful
Let me gather more context about this code to better understand the coverage logic for missing elements.

Let me search for more context about the initialization and usage of `nbatches` to better understand the coverage requirements.

Let me search for the initialization of `nbatches` to understand its purpose in the coverage logic.

Let me search for more context about how `nbatches` is used and its relationship with the dataset statistics.
Based on the gathered context, I can now provide a final response about the coverage logic for missing elements.
The current partial coverage logic is intentional and correct.
The code's behavior is appropriate because:
- `nbatches` is a parameter that defines the desired number of batches for collecting statistics, used consistently throughout the `make_stat_input` function.
- The break condition `if len(lst[i]['atype']) >= nbatches` ensures each system collects up to the specified number of batches, preventing over-sampling.
- This is consistent with how regular batches are collected earlier in the function, where `numb_batches = min(nbatches, len(dataloaders[i]))` is used.
- The goal is to gather a statistically significant sample (controlled by `nbatches`) rather than exhaustively processing all frames containing missing elements.
The current implementation correctly balances between gathering sufficient data for statistics while maintaining consistent batch sizes across all elements.
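As a minimal sketch of that capping pattern (hypothetical helper, not a function in the codebase):

```python
def append_until_cap(entry: dict, frames: list[dict], nbatches: int) -> None:
    """Append frame data to a per-system stat entry, stopping once
    entry["atype"] holds nbatches records (the break condition above)."""
    for frame in frames:
        if len(entry["atype"]) >= nbatches:
            break
        for key, value in frame.items():
            entry.setdefault(key, []).append(value)

entry = {"atype": []}
append_until_cap(entry, [{"atype": [0, 1]}, {"atype": [2]}, {"atype": [0]}], nbatches=2)
print(len(entry["atype"]))  # 2: the third frame is never appended
```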
🏁 Scripts executed
The following scripts were executed for the analysis:
Error: No shell scripts found
Script:
```shell
#!/bin/bash
# Let's first look at the broader context of this code
rg -B 10 -A 5 "for missing_element in missing_elements" deepmd/pt/utils/stat.py

# Also search for any related tests that might help understand the intended behavior
fd "test.*stat.*py" -x rg -l "missing_element"

# Look for the initialization of nbatches and lst variables
rg -B 5 "nbatches.*=|lst.*=" deepmd/pt/utils/stat.py
```
Length of output: 1724
Script:
```shell
#!/bin/bash
# Search for the function definition and its parameters
ast-grep --pattern 'def $FUNC($PARAMS) {
  $$$
  nbatches
  $$$
}' deepmd/pt/utils/stat.py

# Look for any documentation or comments about nbatches
rg -B 3 -A 3 "nbatches" deepmd/pt/utils/stat.py

# Check if there are any tests that verify this behavior
fd "test.*stat.*py" -x cat {}
```
Length of output: 66294
Script:
```shell
#!/bin/bash
# Look for the function definition and documentation
rg -B 10 "def make_stat_input" deepmd/pt/utils/stat.py

# Search for any tests that verify the behavior with missing elements
rg -B 5 -A 5 "missing_element" deepmd/pt/utils/stat.py
```
Length of output: 1079
Script:
```shell
#!/bin/bash
# Look for the full function implementation and its docstring
rg -B 20 -A 50 "def make_stat_input" deepmd/pt/utils/stat.py

# Search for any tests that verify the statistics collection
fd "test.*stat.*py" -x rg -A 10 "make_stat_input"

# Look for the initialization of lst and how frames are collected
rg -B 5 "lst.*=.*\[\]" deepmd/pt/utils/stat.py
```
Length of output: 4539
deepmd/pt/utils/dataset.py
Outdated
```python
def _build_element_to_frames(self):
    """Mapping element types to frame indexes"""
    element_to_frames = {element: [] for element in range(self._ntypes)}
    for frame_idx in range(len(self)):
        frame_data = self._data_system.get_item_torch(frame_idx)
        elements = frame_data["atype"]
        for element in set(elements):
            if len(element_to_frames[element]) < 10:
                element_to_frames[element].append(frame_idx)
    return element_to_frames
```
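For context, `stat.py` (quoted later in this review) probes datasets with `hasattr(dataset, "element_to_frames")` and looks up frames per missing type; a toy model of that lookup (hypothetical mapping values):

```python
# Hypothetical mapping for ntypes=3, in the shape _build_element_to_frames produces.
element_to_frames = {0: [0, 3], 1: [1], 2: []}

missing_element = 2
frame_indices = element_to_frames.get(missing_element, [])
print(frame_indices)  # [] means no frames contain type 2, so nothing can be backfilled
```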
Fix indentation error and consider making the max frame limit configurable.
Static analysis flags a SyntaxError likely due to insufficient indentation after defining the method. Ensure this block is indented so that Python recognizes it as part of the method. Also, the hard-coded limit of 10 frames per element may need to be a configurable parameter if you expect variation in your datasets.
Example indentation fix:

```diff
-def _build_element_to_frames(self):
-    """Mapping element types to frame indexes"""
-    element_to_frames = {element: [] for element in range(self._ntypes)}
+    def _build_element_to_frames(self):
+        """Mapping element types to frame indexes"""
+        element_to_frames = {element: [] for element in range(self._ntypes)}
```
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.8.2)
44-44: SyntaxError: Expected an indented block after function definition
Actionable comments posted: 1
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)
`89-89`: Remove debug print statement.

The print statement `print("we want", all_elements)` appears to be debug code that should be removed or replaced with proper logging.

```diff
- print("we want", all_elements)
+ log.debug(f"Required elements for statistics: {all_elements}")
```
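This assumes a module-level logger is available; a typical setup would be (sketch, not necessarily the file's exact code):

```python
import logging

log = logging.getLogger(__name__)

# Then, inside make_stat_input:
# log.debug(f"Required elements for statistics: {all_elements}")
```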
`97-111`: Optimize nested loops and add error handling.

The nested loops for handling missing elements could be optimized, and error handling should be added for invalid frame indices.

```diff
 for missing_element in missing_elements:
     for i, dataset in enumerate(datasets):
         if hasattr(dataset, "element_to_frames"):
-            frame_indices = dataset.element_to_frames.get(
-                missing_element, []
-            )
-            for frame_idx in frame_indices:
-                if len(lst[i]["atype"]) >= nbatches:
-                    break
-                frame_data = dataset[frame_idx]
-                for key in frame_data:
-                    if key not in lst[i]:
-                        lst[i][key] = []
-                    lst[i][key].append(frame_data[key])
+            try:
+                frame_indices = dataset.element_to_frames.get(missing_element, [])
+                if not frame_indices:
+                    continue
+
+                # Pre-check if we need more frames
+                if len(lst[i]["atype"]) >= nbatches:
+                    break
+
+                # Process frames in batch
+                for frame_idx in frame_indices:
+                    frame_data = dataset[frame_idx]
+                    if any(key not in lst[i] for key in frame_data):
+                        lst[i].update({key: [] for key in frame_data if key not in lst[i]})
+                    for key in frame_data:
+                        lst[i][key].append(frame_data[key])
+                    if len(lst[i]["atype"]) >= nbatches:
+                        break
+            except Exception as e:
+                log.warning(f"Error processing frames for element {missing_element}: {e}")
+                continue
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- `deepmd/pt/utils/dataset.py` (1 hunks)
- `deepmd/pt/utils/stat.py` (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py
44-44: SyntaxError: Expected an indented block after function definition
🔇 Additional comments (2)
deepmd/pt/utils/dataset.py (2)
`43-44`: Fix the indentation error in the method definition.

The method definition has incorrect indentation, which will cause a SyntaxError. It should be aligned with other class methods.
```diff
-def _build_element_to_frames(self):
-    """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self):
+        """Mapping element types to frame indexes"""
```
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
44-44: SyntaxError: Expected an indented block after function definition
`43-53`: 🛠️ Refactor suggestion

Make frame limit configurable and enhance documentation.
- The hard-coded limit of 10 frames per element should be configurable.
- The docstring should be more descriptive about the method's purpose and return value.
```diff
-    def _build_element_to_frames(self):
-        """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self, max_frames_per_element: int = 10) -> dict[int, list[int]]:
+        """Build a mapping of element types to their corresponding frame indices.
+
+        Args:
+            max_frames_per_element: Maximum number of frames to store per element type.
+
+        Returns:
+            A dictionary mapping element types (int) to lists of frame indices (list[int])
+            where each element type appears.
+        """
         element_to_frames = {element: [] for element in range(self._ntypes)}
         for frame_idx in range(len(self)):
             frame_data = self._data_system.get_item_torch(frame_idx)
             elements = frame_data["atype"]
             for element in set(elements):
-                if len(element_to_frames[element]) < 10:
+                if len(element_to_frames[element]) < max_frames_per_element:
                     element_to_frames[element].append(frame_idx)
         return element_to_frames
```
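If adopted, the cap could then be tuned per call; a standalone demo of the parameter's effect (fake frames, logic copied from the method above):

```python
def build_mapping(frames: list[dict], ntypes: int, max_frames_per_element: int = 10):
    # Same logic as _build_element_to_frames, extracted for demonstration.
    mapping = {element: [] for element in range(ntypes)}
    for frame_idx, frame_data in enumerate(frames):
        for element in set(frame_data["atype"]):
            if len(mapping[element]) < max_frames_per_element:
                mapping[element].append(frame_idx)
    return mapping

frames = [{"atype": [0, 1]}, {"atype": [1]}, {"atype": [1]}]
print(build_mapping(frames, ntypes=2, max_frames_per_element=1))  # {0: [0], 1: [0]}
```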
Likely invalid or redundant comment.
🧰 Tools
🪛 Ruff (0.8.2)
44-44: SyntaxError: Expected an indented block after function definition
deepmd/pt/utils/stat.py:

```python
all_elements = set()
if datasets and hasattr(datasets[0], "element_to_frames"):
    all_elements.update(datasets[0].element_to_frames.keys())
print("we want", all_elements)

collected_elements = set()
for sys_stat in lst:
    if "atype" in sys_stat:
        collected_elements.update(np.unique(sys_stat["atype"].cpu().numpy()))
missing_elements = all_elements - collected_elements

for missing_element in missing_elements:
    for i, dataset in enumerate(datasets):
        if hasattr(dataset, "element_to_frames"):
            frame_indices = dataset.element_to_frames.get(
                missing_element, []
            )
            for frame_idx in frame_indices:
                if len(lst[i]["atype"]) >= nbatches:
                    break
                frame_data = dataset[frame_idx]
                for key in frame_data:
                    if key not in lst[i]:
                        lst[i][key] = []
                    lst[i][key].append(frame_data[key])

collected_elements = set()
for sys_stat in lst:
    if "atype" in sys_stat:
        collected_elements.update(
            np.unique(sys_stat["atype"].cpu().numpy())
        )
```
💡 Codebase verification
Based on the code analysis, I can now provide a final response:
Add validation check for missing elements after collection
After the final collection of elements, add a validation check to ensure all required elements were found:
```python
# After the final collected_elements update
if all_elements - collected_elements:
    missing = all_elements - collected_elements
    warnings.warn(f"Could not find frames containing elements: {missing}")
```

- Add this check at line 118 in `deepmd/pt/utils/stat.py`, after the final collection loop.
- Import the `warnings` module at the top of the file.
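A self-contained version of that suggestion (same names as the snippet above):

```python
import warnings

def warn_uncovered(all_elements: set, collected_elements: set) -> None:
    """Emit a warning when required element types never appeared in the stats."""
    missing = all_elements - collected_elements
    if missing:
        warnings.warn(f"Could not find frames containing elements: {missing}")

warn_uncovered({0, 1, 2}, {0, 1})  # UserWarning: ... elements: {2}
```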
🔗 Analysis chain
Consider adding validation for collected elements.
After collecting elements, it would be good to verify that all required elements were found and log a warning if any are still missing.
Let me analyze the code further to understand the element collection and validation context.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
```shell
#!/bin/bash
# Description: Check if there are any elements that might be consistently missing across datasets
# Search for element type definitions and usages
rg -A 5 "atype|element_to_frames"
```
Length of output: 84908
Script:
```shell
#!/bin/bash
# Search for element validation or error handling related to missing elements
rg -A 5 "missing_elements|element_to_frames|all_elements"
```
Length of output: 2630
```diff
@@ -40,6 +40,18 @@
         b_data["natoms"] = self._natoms_vec
         return b_data

+    def _build_element_to_frames(self):
```
Check warning (Code scanning / CodeQL): Unreachable code
```python
collected_elements.update(
    np.unique(sys_stat["atype"].cpu().numpy())
)
```
Check notice (Code scanning / CodeQL): Nested loops with same variable
Summary by CodeRabbit

New Features

- Datasets now expose a mapping from element types to representative frames, used to make bias statistics cover all element types.

Bug Fixes

- Statistics collection now backfills data for element types missing from the initially sampled batches.